Conditional Probability

Author

Prof. Calvin

Published

February 17, 2025

Abstract:

0. Quarto Type-setting

  • This document is rendered with Quarto and configured to embed images using the embed-resources option in the header.
  • If you wish to use a similar header, here is the format specification for this document:

1. Setup

sh <- suppressPackageStartupMessages
sh(library(tidyverse))
sh(library(caret))
sh(library(naivebayes)) # bae caught me naivin'
sh(library(tidytext))
wine <- readRDS(gzcon(url("https://github.com/cd-public/D505/raw/master/dat/pinot.rds")))

2. Conditional Probability

  • Calculate the probability that a Pinot comes from Burgundy given that it has the word ‘fruit’ in the description.
    • Take \(A\) to be the event that a Pinot was grown in Burgundy.
    • Take \(B\) to be the event that a Pinot has the word ‘fruit’ in its description.

\[ P(A|B) = \frac{P(A \cap B)}{P(B)} \]

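# P(A|B): share of Burgundy wines among wines whose description contains "fruit"
# (a substring match, so "fruity" and "fruits" also count)
nrow(filter(wine, province == "Burgundy" & str_detect(description, "fruit"))) /
  nrow(filter(wine, str_detect(description, "fruit")))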
[1] 0.2196038
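
As a cross-check on the calculation above, the same conditional probability can be reached through Bayes' rule, \(P(A|B) = P(B|A)P(A)/P(B)\). The sketch below uses a small made-up data frame (`toy` is hypothetical and is not the pinot data set) and base R only, so it runs without downloading anything.

```r
# Hypothetical toy data, NOT the pinot data set: six wines.
toy <- data.frame(
  province = c("Burgundy", "Burgundy", "Oregon", "Oregon", "California", "Burgundy"),
  description = c("red fruit and earth", "firm tannins", "ripe fruit",
                  "dark fruit", "oak and spice", "bright fruit"),
  stringsAsFactors = FALSE
)

a <- toy$province == "Burgundy"                      # event A: grown in Burgundy
b <- grepl("fruit", toy$description, fixed = TRUE)   # event B: mentions "fruit"

# Direct definition: P(A and B) / P(B)
p_a_given_b_direct <- sum(a & b) / sum(b)

# Bayes' rule: P(B|A) * P(A) / P(B)
p_b_given_a <- sum(a & b) / sum(a)
p_a_given_b_bayes <- p_b_given_a * mean(a) / mean(b)

c(direct = p_a_given_b_direct, bayes = p_a_given_b_bayes)
```

Both routes give the same number because they are algebraically identical. The Naive Bayes classifier in the next section relies on the second form.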

3. Naive Bayes Algorithm

  • We train a Naive Bayes classifier to predict a wine’s province using:
  1. An 80-20 train-test split.
  2. Three features engineered from the description.
  3. 5-fold cross-validation.
  • We report Kappa after using the model to predict provinces in the holdout sample.
# Engineer three logical features from the description, then drop the raw text
wino <- wine %>% 
  mutate(cherry = str_detect(description, "cherry")) %>% 
  mutate(chocolate = str_detect(description, "chocolate")) %>%
  mutate(earth = str_detect(description, "earth")) %>%
  select(-description)

wine_index <- createDataPartition(wino$province, p = 0.80, list = FALSE)
train <- wino[ wine_index, ]
test <- wino[-wine_index, ]

fit <- train(province ~ .,
             data = train, 
             method = "naive_bayes",
             metric = "Kappa",
             trControl = trainControl(method = "cv", number = 5))

confusionMatrix(predict(fit, test),factor(test$province))
Confusion Matrix and Statistics

                   Reference
Prediction          Burgundy California Casablanca_Valley Marlborough New_York
  Burgundy               171         77                 8          13        5
  California              64        703                16          24       15
  Casablanca_Valley        0          0                 0           0        0
  Marlborough              0          0                 0           0        0
  New_York                 0          0                 0           0        2
  Oregon                   3         11                 2           8        4
                   Reference
Prediction          Oregon
  Burgundy             146
  California           305
  Casablanca_Valley      0
  Marlborough            0
  New_York               0
  Oregon                96

Overall Statistics
                                          
               Accuracy : 0.581           
                 95% CI : (0.5569, 0.6048)
    No Information Rate : 0.4728          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.3258          
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: Burgundy Class: California Class: Casablanca_Valley
Sensitivity                   0.7185            0.8887                  0.00000
Specificity                   0.8265            0.5193                  1.00000
Pos Pred Value                0.4071            0.6238                      NaN
Neg Pred Value                0.9465            0.8388                  0.98446
Prevalence                    0.1423            0.4728                  0.01554
Detection Rate                0.1022            0.4202                  0.00000
Detection Prevalence          0.2510            0.6736                  0.00000
Balanced Accuracy             0.7725            0.7040                  0.50000
                     Class: Marlborough Class: New_York Class: Oregon
Sensitivity                      0.0000        0.076923       0.17550
Specificity                      1.0000        1.000000       0.97513
Pos Pred Value                      NaN        1.000000       0.77419
Neg Pred Value                   0.9731        0.985637       0.70884
Prevalence                       0.0269        0.015541       0.32696
Detection Rate                   0.0000        0.001195       0.05738
Detection Prevalence             0.0000        0.001195       0.07412
Balanced Accuracy                0.5000        0.538462       0.57532
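
To make the Kappa figure less of a black box, here is a minimal sketch that recomputes accuracy and Cohen's Kappa by hand from the counts printed in the confusion matrix above. The variable names (`cm`, `p_observed`, `p_expected`, `kappa_hat`) are introduced only for this illustration. Kappa compares observed agreement with the agreement expected by chance from the row and column totals.

```r
# Counts copied from the confusion matrix above
# (rows = prediction, columns = reference).
classes <- c("Burgundy", "California", "Casablanca_Valley",
             "Marlborough", "New_York", "Oregon")
cm <- matrix(c(171,  77,  8, 13,  5, 146,
                64, 703, 16, 24, 15, 305,
                 0,   0,  0,  0,  0,   0,
                 0,   0,  0,  0,  0,   0,
                 0,   0,  0,  0,  2,   0,
                 3,  11,  2,  8,  4,  96),
             nrow = 6, byrow = TRUE,
             dimnames = list(Prediction = classes, Reference = classes))

n <- sum(cm)                                          # size of the holdout sample
p_observed <- sum(diag(cm)) / n                       # overall accuracy
p_expected <- sum(rowSums(cm) * colSums(cm)) / n^2    # agreement expected by chance
kappa_hat <- (p_observed - p_expected) / (1 - p_expected)

round(c(accuracy = p_observed, kappa = kappa_hat), 4)
```

The results should land close to the Accuracy and Kappa values reported by confusionMatrix. The gap between the two is large because a model that mostly predicts California and Burgundy already agrees with the reference fairly often by chance.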

4. Frequency Differences

  • We find the three words that most distinguish New York Pinots from all other Pinots.

Calculate relative word count.

wc <- function(df, omits) {
  n_docs <- nrow(df)                      # number of descriptions, used to normalize
  df %>%
    unnest_tokens(word, description) %>%  # one row per word
    anti_join(stop_words) %>%             # drop common stop words
    filter(!(word %in% omits)) %>% 
    count(word) %>%                       # occurrences of each word
    mutate(n = n / n_docs)                # occurrences per description
}
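
The normalization in wc() divides each word's total count by the number of descriptions, so the result is occurrences per description rather than the share of descriptions containing the word. The base R sketch below illustrates that on two made-up descriptions. `rel_freq` is a hypothetical helper written for this example: it uses a crude regular-expression tokenizer instead of unnest_tokens and does not remove stop words.

```r
# Hypothetical helper for illustration only; it approximates wc() with base R.
rel_freq <- function(descriptions, omits = character()) {
  n_docs <- length(descriptions)
  # Lowercase, then split on anything that is not a letter or apostrophe
  words <- unlist(strsplit(tolower(descriptions), "[^a-z']+"))
  words <- words[nzchar(words) & !(words %in% omits)]
  counts <- table(words)
  # Named numeric vector: occurrences of each word per description
  setNames(as.numeric(counts), names(counts)) / n_docs
}

toy <- c("Cherry and earth", "Cherry, cherry, spice")
freq <- rel_freq(toy, omits = c("and"))
freq
```

Here "cherry" appears three times across two descriptions, so its value exceeds 1. The same can happen with wc() whenever a word is repeated within a description.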

Make corresponding dataframes.

omits <- c("pinot", "noir", "wine")
wc_ny <- wc(wine %>% filter(province=="New_York") %>% select(description), omits)
Joining with `by = join_by(word)`
wc_no <- wc(wine %>% filter(province!="New_York") %>% select(description), omits)
Joining with `by = join_by(word)`

Calculate difference in relative frequencies.

diff <- wc_ny %>%
    inner_join(wc_no, by = "word", suffix = c("_ny", "_no")) %>%
    mutate(diff = n_ny - n_no) %>%
    arrange(desc(abs(diff)))
    
diff %>% head(3)
# A tibble: 3 × 4
  word     n_ny  n_no  diff
  <chr>   <dbl> <dbl> <dbl>
1 cherry  0.916 0.431 0.485
2 tannins 0.580 0.234 0.346
3 palate  0.565 0.239 0.326

Plot it.

sh(library(plotly))

plot_ly(diff %>% slice_max(diff, n = 25), 
        x = ~n_ny, y = ~n_no, z = ~diff, text = ~word, 
        type = 'scatter3d', mode = 'markers+text', 
        marker = list(size = 5, color = ~diff, showscale = TRUE),
        textposition = 'top right') %>%
    layout(title = "3D Scatterplot of Word Frequencies",
           scene = list(
               xaxis = list(title = "Frequency in New York Pinots"),
               yaxis = list(title = "Frequency in Other Pinots"),
               zaxis = list(title = "Difference in Frequency")
           ))